# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import os
import json
import torchvision
import numpy as np
import math

from torchvision import transforms
from .datasetbase import BasicDataset
from semilearn.datasets.augmentation import RandAugment, RandomResizedCropAndInterpolation
from semilearn.datasets.utils import split_ssl_data

from semilearn.datasets.cv_datasets.common.utils import check_if_file_exists, join
from semilearn.datasets.cv_datasets.factory import DatasetFactory
from torchvision.datasets import VisionDataset
# Per-channel (R, G, B) normalisation statistics, keyed by dataset name.
# The 'ava' values equal the widely used ImageNet mean/std
# (presumably reused rather than computed on AVA itself -- confirm).
mean = {'ava': [0.485, 0.456, 0.406]}
std = {'ava': [0.229, 0.224, 0.225]}

class AVABaseDataset(VisionDataset):
    """
    AVABaseDataset that contains all the elements in AVA datasets,
    which can be inherited into the following datasets:
    1) aesthetic classification datasets
    2) aesthetic score regression datasets
    3) style classification datasets

    The elements in AVABaseDataset include:
    1) all images
    2) aesthetic scores count
    3) averaged aesthetic scores
    4) aesthetic scores count distribution
    5) aesthetic classification scores
    6) style categories
    7) challenge categories
    8) content categories #TODO for now not supported

    The base class only parses the annotation files; ``images`` and
    ``targets`` stay ``None`` until a subclass assigns ``_images`` and
    ``_targets``, so indexing or ``len()`` on a bare instance fails.

    Expected layout below ``root`` (as read by ``__init__``):
        images/images/<image_id>.jpg
        aesthetics_image_lists/*.jpgl
        style_image_lists/{train,test}.jpgl, train.lab, test.multilab, styles.txt
        AVA.txt, tags.txt, challenges.txt
    """

    CORRUPTED_IMAGE_IDS = [
        '512522', '627334', '523555', '593733',
        '230701', '564093', '501015', '371434',
        '440774', '501064', '453704', '2129',
        '179118', '1617', '476416', '570175',
        '547917', '502377', '499068', '564307',
        '729377', '639811', '532055', '501095',
        '277832', '556798',
    ]

    # Small-scale training list; never merged into the generic aesthetics splits.
    _SMALL_SCALE_TRAIN_LIST = 'generic_ss_train.jpgl'

    def __init__(self, root, split='train', transforms=None):
        """
        Parse every AVA annotation file found under ``root``.

        Args:
            root: dataset root directory (see the class docstring for the layout).
            split: either ``'train'`` or ``'test'``.
            transforms: optional callable applied to the loaded RGB image in
                ``__getitem__`` (it is called with the image only).

        Raises:
            AssertionError: if ``split`` is unsupported or an annotation file
                is malformed.
        """
        super().__init__(root=root, transforms=transforms)
        assert split in ['train', 'test'], 'Got unsupported split: `%s`' % split
        self.split = split

        self.image_root = join(self.root, 'images/images')
        self.aesthetics_split_path = join(self.root, 'aesthetics_image_lists')
        self.style_split_path = join(self.root, 'style_image_lists')
        self.ava_txt_path = join(self.root, 'AVA.txt')
        self.tags_txt_path = join(self.root, 'tags.txt')
        self.challenges_txt_path = join(self.root, 'challenges.txt')
        self.style_txt_path = join(self.root, 'style_image_lists/styles.txt')
        self.style_content_path = join(self.root, 'style_image_lists')

        check_if_file_exists(self.aesthetics_split_path)
        check_if_file_exists(self.style_split_path)
        check_if_file_exists(self.ava_txt_path)
        check_if_file_exists(self.tags_txt_path)
        check_if_file_exists(self.challenges_txt_path)
        check_if_file_exists(self.style_txt_path)
        check_if_file_exists(self.style_content_path)

        self.aesthetics_test_split = self._process_aesthetics_split('test')
        self.aesthetics_train_split = self._process_aesthetics_split('train')

        # The style split must be read before the style labels: the labels are
        # matched to image ids purely by line position.
        self.style_split = self._process_style_split()
        self.imageId2ava_contents, all_aesthetics_imageIds = self._process_ava_txt()
        self.imageId2style = self._process_style_content()
        self.semanticId2name = self._process_id2name_txt(self.tags_txt_path)
        self.challengeId2name = self._process_id2name_txt(self.challenges_txt_path)
        self.style2name = self._process_id2name_txt(self.style_txt_path)

        self.aesthetics_split = self._post_process_aesthetics_split(self.split,
                                                                    self.aesthetics_test_split,
                                                                    all_aesthetics_imageIds)

        self._images = None
        self._targets = None
        self.all_imageIds = self._get_all_image_list()
        # Keep only ids that have an image on disk and are not known-corrupted.
        self.aesthetics_split = self._post_process_image_list(self.aesthetics_split, self.all_imageIds,
                                                              self.CORRUPTED_IMAGE_IDS)
        self.style_split = self._post_process_image_list(self.style_split, self.all_imageIds, self.CORRUPTED_IMAGE_IDS)

    @property
    def images(self):
        """Image ids served by this dataset (``None`` until a subclass sets them)."""
        return self._images

    @property
    def targets(self):
        """Targets aligned with ``images`` (``None`` until a subclass sets them)."""
        return self._targets

    def __getitem__(self, index):
        """Return ``(image, target)`` for ``index``; the image is loaded as RGB."""
        image_id = self.images[index]
        image_path = join(self.image_root, image_id + '.jpg')
        # The module never imports PIL's ``Image``, so go through torchvision's
        # loader instead: it opens the file in a ``with`` block and returns
        # ``Image.open(f).convert('RGB')``.
        image = torchvision.datasets.folder.pil_loader(image_path)
        target = self.targets[index]

        if self.transforms:
            image = self.transforms(image)

        return image, target

    def __len__(self):
        return len(self.images)

    def _process_aesthetics_split(self, split):
        """
        Collect the unique image ids of every aesthetics list file whose name
        contains ``split``, ignoring the small-scale training list.

        Returns:
            list of image-id strings in arbitrary order.
        """
        list_dir = self.aesthetics_split_path
        image_ids = set()
        for filename in os.listdir(list_dir):
            # Filtering by name (rather than ``list.remove``) keeps this working
            # when the small-scale list is not shipped with the data.
            if filename == self._SMALL_SCALE_TRAIN_LIST or split not in filename:
                continue
            with open(join(list_dir, filename)) as list_file:
                for line in list_file:
                    image_id = line.strip()
                    if image_id:
                        image_ids.add(image_id)
        return list(image_ids)

    @staticmethod
    def _post_process_aesthetics_split(split, test_split, all_aesthetics_imageIds):
        """
        Derive the final aesthetics split from the ids listed in AVA.txt.

        The test split is the listed test ids that also appear in AVA.txt;
        the train split is every other id from AVA.txt.
        """
        # TODO: with image exists check
        test_ids = set(test_split)
        all_ids = set(all_aesthetics_imageIds)
        if split == 'test':
            return list(test_ids & all_ids)
        return list(all_ids - test_ids)

    def _process_style_split(self):
        """Read ``<split>.jpgl`` and return its image ids in file order."""
        list_path = join(self.style_split_path, '%s.jpgl' % self.split)
        with open(list_path) as list_file:
            # Blank lines are deliberately kept: style labels are aligned with
            # this list by line position.
            return [line.replace('\n', '') for line in list_file]

    def get_classification_score(self, score):
        """
        Map a 0-based rating-bin index to one of five classes by pairing
        adjacent bins: 0-1 -> 0, 2-3 -> 1, 4-5 -> 2, 6-7 -> 3, 8-9 -> 4.

        Returns ``None`` when ``score`` falls in none of these ranges.
        """
        for label, low in enumerate(range(0, 10, 2)):
            if low <= score <= low + 1:
                return label
        return None

    def _process_ava_txt(self):
        """
        Parse AVA.txt.

        Each line looks like
            1 953619 0 1 5 17 38 36 15 6 5 1 1 22 1396
        i.e. index, image id, score_cnt x 10, semantic_id x 2, challenge_id.

        Returns:
            (imageId2ava_contents, imageIds): a dict from image id to its parsed
            record, and the list of image ids in file order.

        Raises:
            AssertionError: if a non-blank line does not have 15 fields.
        """
        imageId2ava_contents = {}
        imageIds = []
        with open(self.ava_txt_path) as ava_file:
            for line in ava_file:
                fields = line.split()
                if not fields:
                    # tolerate blank / trailing empty lines
                    continue
                assert len(fields) == 15, 'Corrupted AVA.txt'
                # fields[0] is the running index and is not used
                image_id = fields[1]
                counts = [int(value) for value in fields[2:12]]
                semantic_ids = [int(value) for value in fields[12:14]]
                challenge_id = int(fields[14])

                sum_counts = sum(counts)
                # must start from 1, because rating is from 1 to 10
                sum_scores = sum(rating * count for rating, count in enumerate(counts, start=1))
                # 0-based index of the most voted rating (first one on ties)
                max_index = counts.index(max(counts))

                imageIds.append(image_id)
                imageId2ava_contents[image_id] = {
                    'image_id': image_id,
                    'counts': counts,
                    'semantic_ids': semantic_ids,
                    'challenge_id': challenge_id,
                    'count_distribution': [count / sum_counts for count in counts],
                    'regression_score': sum_scores / sum_counts,
                    'classification_score': self.get_classification_score(max_index),
                }

        return imageId2ava_contents, imageIds

    def _process_style_content(self):
        """
        Read the style labels of the current split and pair them with
        ``self.style_split`` by line position.

        For the train split each line of ``train.lab`` is a single integer
        label; for the test split each line of ``test.multilab`` is a
        space-separated 0/1 vector that is converted to the list of 1-based
        positions holding ``'1'``.
        """
        split = self.split
        filename = 'train.lab' if split == 'train' else 'test.multilab'
        label_path = join(self.style_content_path, filename)
        with open(label_path) as label_file:
            if split == 'train':
                style_list = [int(line.replace('\n', '')) for line in label_file]
            else:
                style_list = [
                    [position for position, flag in enumerate(line.replace('\n', '').split(' '), 1)
                     if flag == '1']
                    for line in label_file
                ]

        # NOTE: zip silently truncates if the id list and label file differ in length.
        return dict(zip(self.style_split, style_list))

    def _get_all_image_list(self):
        """Return the id (file name up to the first dot) of every entry in the image directory."""
        return [name.split('.')[0] for name in os.listdir(self.image_root)]

    @staticmethod
    def _process_id2name_txt(path):
        """
        Parse a ``<id> <name>`` text file into a dict.

        Ids are kept as strings; the name is everything after the first space.

        Raises:
            AssertionError: if a non-blank line has no space-separated name.
        """
        id2name = {}
        with open(path) as id_file:
            for line in id_file:
                line = line.strip()
                if not line:
                    # tolerate blank / trailing empty lines
                    continue
                parts = line.split(' ', 1)
                assert len(parts) == 2, 'Corrupted file: `%s`' % path
                _id, name = parts
                id2name[_id] = name

        return id2name

    @staticmethod
    def _post_process_image_list(image_list, all_image_list, CORRUPTED_IMAGE_IDS):
        """Return the ids of ``image_list`` that exist on disk and are not known-corrupted."""
        usable_ids = set(all_image_list) - set(CORRUPTED_IMAGE_IDS)
        return list(set(image_list) & usable_ids)


@DatasetFactory.register('AVAAestheticClassificationDataset')
class AVAAestheticClassificationDataset(AVABaseDataset):
    """
    AVA dataset variant used for aesthetic classification.

    Samples are the image ids of the aesthetics split selected by ``split``.
    Each target is that image's ``classification_score`` entry from
    ``imageId2ava_contents``: the class (0-4) that the base class derives
    from the most-voted rating bin via ``get_classification_score``.
    """

    def __init__(self, root, split='train', transforms=None):
        super().__init__(root=root, split=split, transforms=transforms)
        image_ids = self.aesthetics_split
        ava_contents = self.imageId2ava_contents
        self._images = image_ids
        # one label per image id, in the same order as ``_images``
        self._targets = [ava_contents[image_id]['classification_score'] for image_id in image_ids]

def get_ava(args, alg, name, num_labels, num_classes, data_dir='./data', include_lb_to_ulb=True):
    """
    Build the labeled, unlabeled and evaluation datasets for AVA.

    Args:
        args: configuration object; this function reads ``img_size``,
            ``crop_ratio``, ``ulb_num_labels``, ``lb_imb_ratio`` and
            ``ulb_imb_ratio`` and forwards ``args`` to ``split_ssl_data``.
        alg: algorithm name; ``'fullysupervised'`` makes the labeled set the
            whole training split.
        name: dataset name (not used by this function).
        num_labels: number of labeled samples requested from ``split_ssl_data``.
        num_classes: number of target classes.
        data_dir: root directory of the AVA data.
        include_lb_to_ulb: forwarded to ``split_ssl_data``.

    Returns:
        ``(lb_dset, ulb_dset, eval_dset)`` as ``BasicDataset`` instances.
    """
    train_set = AVAAestheticClassificationDataset(root=data_dir, split='train')
    data = np.array(train_set.images)
    target = train_set.targets

    img_size = args.img_size
    crop_ratio = args.crop_ratio

    # Side length images are resized to before a crop of ``img_size`` is taken.
    resize_side = int(math.floor(img_size / crop_ratio))

    def to_normalized_tensor():
        # fresh transform instances for every pipeline
        return [transforms.ToTensor(), transforms.Normalize(mean['ava'], std['ava'])]

    transform_weak = transforms.Compose([
        transforms.Resize((resize_side, resize_side)),
        transforms.RandomCrop((img_size, img_size)),
        transforms.RandomHorizontalFlip(),
        *to_normalized_tensor(),
    ])

    transform_strong = transforms.Compose([
        transforms.Resize((resize_side, resize_side)),
        RandomResizedCropAndInterpolation((img_size, img_size)),
        transforms.RandomHorizontalFlip(),
        RandAugment(3, 10),
        *to_normalized_tensor(),
    ])

    # Evaluation passes a single int to Resize (not a pair), then center-crops.
    transform_val = transforms.Compose([
        transforms.Resize(math.floor(int(img_size / crop_ratio))),
        transforms.CenterCrop(img_size),
        *to_normalized_tensor(),
    ])

    lb_data, lb_targets, ulb_data, ulb_targets = split_ssl_data(
        args, data, target, num_classes,
        lb_num_labels=num_labels,
        ulb_num_labels=args.ulb_num_labels,
        lb_imbalance_ratio=args.lb_imb_ratio,
        ulb_imbalance_ratio=args.ulb_imb_ratio,
        include_lb_to_ulb=include_lb_to_ulb,
    )

    def tally(labels):
        # per-class sample counts, indexed by class id
        histogram = [0] * num_classes
        for label in labels:
            histogram[label] += 1
        return histogram

    lb_count = tally(lb_targets)
    ulb_count = tally(ulb_targets)
    print("lb count: {}".format(lb_count))
    print("ulb count: {}".format(ulb_count))

    if alg == 'fullysupervised':
        # supervised baseline: train on every training sample with its label
        lb_data, lb_targets = data, target

    lb_dset = BasicDataset(alg, lb_data, lb_targets, num_classes, transform_weak, False, None, False)

    ulb_dset = BasicDataset(alg, ulb_data, ulb_targets, num_classes, transform_weak, True, transform_strong, False)

    test_set = AVAAestheticClassificationDataset(root=data_dir, split='test')
    test_data = np.array(test_set.images)
    test_target = test_set.targets

    eval_dset = BasicDataset(alg, test_data, test_target, num_classes, transform_val, False, None, False)

    return lb_dset, ulb_dset, eval_dset
